// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include <gtest/gtest.h>

#include <core/ttnn_all_includes.hpp>

#include "autograd/tensor.hpp"
#include "core/tt_tensor_utils.hpp"
#include "core/xtensor_utils.hpp"
#include "modules/positional_embeddings.hpp"
#include "modules/rotary_embedding.hpp"
#include "ops/losses.hpp"

namespace ttml::modules::tests {

/// GoogleTest fixture for the rotary positional embedding (RoPE) tests.
///
/// Before each test case it opens the device of the autograd context returned
/// by `ttml::autograd::ctx()`. After each test case it closes that device
/// again, so every TEST_F using this fixture runs between an open/close pair.
class RoPETest : public ::testing::Test {
protected:
    // Runs before every test body: bring the device up.
    void SetUp() override {
        auto&& context = ttml::autograd::ctx();
        context.open_device();
    }

    // Runs after every test body (including after failed assertions):
    // shut the device down so the next test starts from a closed device.
    void TearDown() override {
        auto&& context = ttml::autograd::ctx();
        context.close_device();
    }
};

TEST_F(RoPETest, GeneratedParamsOk) {
    xt::xarray<float> expected_cos = {
        {{{1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F,
           1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F,
           1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F},
          {0.54030F, 0.54030F, 0.84601F, 0.84601F, 0.95042F, 0.95042F, 0.98423F, 0.98423F, 0.99500F, 0.99500F, 0.99842F,
           0.99842F, 0.99950F, 0.99950F, 0.99984F, 0.99984F, 0.99995F, 0.99995F, 0.99998F, 0.99998F, 0.99999F, 0.99999F,
           1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F},
          {-0.41615F, -0.41615F, 0.43146F, 0.43146F, 0.80658F, 0.80658F, 0.93742F, 0.93742F,
           0.98007F,  0.98007F,  0.99368F, 0.99368F, 0.99800F, 0.99800F, 0.99937F, 0.99937F,
           0.99980F,  0.99980F,  0.99994F, 0.99994F, 0.99998F, 0.99998F, 0.99999F, 0.99999F,
           1.00000F,  1.00000F,  1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F},
          {-0.98999F, -0.98999F, -0.11597F, -0.11597F, 0.58275F, 0.58275F, 0.86104F, 0.86104F,
           0.95534F,  0.95534F,  0.98580F,  0.98580F,  0.99550F, 0.99550F, 0.99858F, 0.99858F,
           0.99955F,  0.99955F,  0.99986F,  0.99986F,  0.99995F, 0.99995F, 0.99999F, 0.99999F,
           1.00000F,  1.00000F,  1.00000F,  1.00000F,  1.00000F, 1.00000F, 1.00000F, 1.00000F},
          {-0.65364F, -0.65364F, -0.62768F, -0.62768F, 0.30114F, 0.30114F, 0.75751F, 0.75751F,
           0.92106F,  0.92106F,  0.97481F,  0.97481F,  0.99201F, 0.99201F, 0.99747F, 0.99747F,
           0.99920F,  0.99920F,  0.99975F,  0.99975F,  0.99992F, 0.99992F, 0.99997F, 0.99997F,
           0.99999F,  0.99999F,  1.00000F,  1.00000F,  1.00000F, 1.00000F, 1.00000F, 1.00000F},
          {0.28366F, 0.28366F, -0.94608F, -0.94608F, -0.01034F, -0.01034F, 0.63008F, 0.63008F,
           0.87758F, 0.87758F, 0.96073F,  0.96073F,  0.98753F,  0.98753F,  0.99605F, 0.99605F,
           0.99875F, 0.99875F, 0.99960F,  0.99960F,  0.99988F,  0.99988F,  0.99996F, 0.99996F,
           0.99999F, 0.99999F, 1.00000F,  1.00000F,  1.00000F,  1.00000F,  1.00000F, 1.00000F},
          {0.96017F, 0.96017F, -0.97310F, -0.97310F, -0.32080F, -0.32080F, 0.48278F, 0.48278F,
           0.82534F, 0.82534F, 0.94362F,  0.94362F,  0.98205F,  0.98205F,  0.99431F, 0.99431F,
           0.99820F, 0.99820F, 0.99943F,  0.99943F,  0.99982F,  0.99982F,  0.99994F, 0.99994F,
           0.99998F, 0.99998F, 0.99999F,  0.99999F,  1.00000F,  1.00000F,  1.00000F, 1.00000F},
          {0.75390F, 0.75390F, -0.70043F, -0.70043F, -0.59944F, -0.59944F, 0.32026F, 0.32026F,
           0.76484F, 0.76484F, 0.92352F,  0.92352F,  0.97560F,  0.97560F,  0.99226F, 0.99226F,
           0.99755F, 0.99755F, 0.99923F,  0.99923F,  0.99976F,  0.99976F,  0.99992F, 0.99992F,
           0.99998F, 0.99998F, 0.99999F,  0.99999F,  1.00000F,  1.00000F,  1.00000F, 1.00000F},
          {-0.14550F, -0.14550F, -0.21204F, -0.21204F, -0.81863F, -0.81863F, 0.14763F, 0.14763F,
           0.69671F,  0.69671F,  0.90050F,  0.90050F,  0.96817F,  0.96817F,  0.98990F, 0.98990F,
           0.99680F,  0.99680F,  0.99899F,  0.99899F,  0.99968F,  0.99968F,  0.99990F, 0.99990F,
           0.99997F,  0.99997F,  0.99999F,  0.99999F,  1.00000F,  1.00000F,  1.00000F, 1.00000F},
          {-0.91113F, -0.91113F, 0.34166F, 0.34166F, -0.95664F, -0.95664F, -0.02965F, -0.02965F,
           0.62161F,  0.62161F,  0.87464F, 0.87464F, 0.95977F,  0.95977F,  0.98722F,  0.98722F,
           0.99595F,  0.99595F,  0.99872F, 0.99872F, 0.99960F,  0.99960F,  0.99987F,  0.99987F,
           0.99996F,  0.99996F,  0.99999F, 0.99999F, 1.00000F,  1.00000F,  1.00000F,  1.00000F},
          {-0.83907F, -0.83907F, 0.79013F, 0.79013F, -0.99979F, -0.99979F, -0.20600F, -0.20600F,
           0.54030F,  0.54030F,  0.84601F, 0.84601F, 0.95042F,  0.95042F,  0.98423F,  0.98423F,
           0.99500F,  0.99500F,  0.99842F, 0.99842F, 0.99950F,  0.99950F,  0.99984F,  0.99984F,
           0.99995F,  0.99995F,  0.99998F, 0.99998F, 0.99999F,  0.99999F,  1.00000F,  1.00000F},
          {0.00443F, 0.00443F, 0.99526F, 0.99526F, -0.94378F, -0.94378F, -0.37585F, -0.37585F,
           0.45360F, 0.45360F, 0.81471F, 0.81471F, 0.94011F,  0.94011F,  0.98093F,  0.98093F,
           0.99396F, 0.99396F, 0.99809F, 0.99809F, 0.99940F,  0.99940F,  0.99981F,  0.99981F,
           0.99994F, 0.99994F, 0.99998F, 0.99998F, 0.99999F,  0.99999F,  1.00000F,  1.00000F},
          {0.84385F, 0.84385F, 0.89386F, 0.89386F, -0.79418F, -0.79418F, -0.53384F, -0.53384F,
           0.36236F, 0.36236F, 0.78083F, 0.78083F, 0.92886F,  0.92886F,  0.97732F,  0.97732F,
           0.99281F, 0.99281F, 0.99772F, 0.99772F, 0.99928F,  0.99928F,  0.99977F,  0.99977F,
           0.99993F, 0.99993F, 0.99998F, 0.99998F, 0.99999F,  0.99999F,  1.00000F,  1.00000F},
          {0.90745F, 0.90745F, 0.51717F, 0.51717F, -0.56582F, -0.56582F, -0.67500F, -0.67500F,
           0.26750F, 0.26750F, 0.74448F, 0.74448F, 0.91668F,  0.91668F,  0.97340F,  0.97340F,
           0.99156F, 0.99156F, 0.99733F, 0.99733F, 0.99916F,  0.99916F,  0.99973F,  0.99973F,
           0.99992F, 0.99992F, 0.99997F, 0.99997F, 0.99999F,  0.99999F,  1.00000F,  1.00000F},
          {0.13674F, 0.13674F, -0.01880F, -0.01880F, -0.28135F, -0.28135F, -0.79487F, -0.79487F,
           0.16997F, 0.16997F, 0.70578F,  0.70578F,  0.90359F,  0.90359F,  0.96917F,  0.96917F,
           0.99022F, 0.99022F, 0.99690F,  0.99690F,  0.99902F,  0.99902F,  0.99969F,  0.99969F,
           0.99990F, 0.99990F, 0.99997F,  0.99997F,  0.99999F,  0.99999F,  1.00000F,  1.00000F},
          {-0.75969F, -0.75969F, -0.54898F, -0.54898F, 0.03102F, 0.03102F, -0.88967F, -0.88967F,
           0.07074F,  0.07074F,  0.66484F,  0.66484F,  0.88959F, 0.88959F, 0.96463F,  0.96463F,
           0.98877F,  0.98877F,  0.99644F,  0.99644F,  0.99888F, 0.99888F, 0.99964F,  0.99964F,
           0.99989F,  0.99989F,  0.99996F,  0.99996F,  0.99999F, 0.99999F, 1.00000F,  1.00000F},
          {-0.95766F, -0.95766F, -0.91008F, -0.91008F, 0.34032F, 0.34032F, -0.95641F, -0.95641F,
           -0.02920F, -0.02920F, 0.62181F,  0.62181F,  0.87471F, 0.87471F, 0.95980F,  0.95980F,
           0.98723F,  0.98723F,  0.99595F,  0.99595F,  0.99872F, 0.99872F, 0.99960F,  0.99960F,
           0.99987F,  0.99987F,  0.99996F,  0.99996F,  0.99999F, 0.99999F, 1.00000F,  1.00000F},
          {-0.27516F, -0.27516F, -0.99090F, -0.99090F, 0.61586F, 0.61586F, -0.99298F, -0.99298F,
           -0.12884F, -0.12884F, 0.57681F,  0.57681F,  0.85895F, 0.85895F, 0.95465F,  0.95465F,
           0.98558F,  0.98558F,  0.99543F,  0.99543F,  0.99856F, 0.99856F, 0.99954F,  0.99954F,
           0.99986F,  0.99986F,  0.99995F,  0.99995F,  0.99999F, 0.99999F, 1.00000F,  1.00000F},
          {0.66032F,  0.66032F,  -0.76654F, -0.76654F, 0.83034F, 0.83034F, -0.99824F, -0.99824F,
           -0.22720F, -0.22720F, 0.52998F,  0.52998F,  0.84233F, 0.84233F, 0.94921F,  0.94921F,
           0.98384F,  0.98384F,  0.99488F,  0.99488F,  0.99838F, 0.99838F, 0.99949F,  0.99949F,
           0.99984F,  0.99984F,  0.99995F,  0.99995F,  0.99998F, 0.99998F, 0.99999F,  0.99999F},
          {0.98870F,  0.98870F,  -0.30610F, -0.30610F, 0.96246F, 0.96246F, -0.97201F, -0.97201F,
           -0.32329F, -0.32329F, 0.48148F,  0.48148F,  0.82487F, 0.82487F, 0.94346F,  0.94346F,
           0.98200F,  0.98200F,  0.99430F,  0.99430F,  0.99820F, 0.99820F, 0.99943F,  0.99943F,
           0.99982F,  0.99982F,  0.99994F,  0.99994F,  0.99998F, 0.99998F, 0.99999F,  0.99999F},
          {0.40808F,  0.40808F,  0.24862F, 0.24862F, 0.99914F, 0.99914F, -0.91513F, -0.91513F,
           -0.41615F, -0.41615F, 0.43146F, 0.43146F, 0.80658F, 0.80658F, 0.93742F,  0.93742F,
           0.98007F,  0.98007F,  0.99368F, 0.99368F, 0.99800F, 0.99800F, 0.99937F,  0.99937F,
           0.99980F,  0.99980F,  0.99994F, 0.99994F, 0.99998F, 0.99998F, 0.99999F,  0.99999F},
          {-0.54773F, -0.54773F, 0.72676F, 0.72676F, 0.93674F, 0.93674F, -0.82938F, -0.82938F,
           -0.50485F, -0.50485F, 0.38008F, 0.38008F, 0.78749F, 0.78749F, 0.93108F,  0.93108F,
           0.97803F,  0.97803F,  0.99304F, 0.99304F, 0.99780F, 0.99780F, 0.99930F,  0.99930F,
           0.99978F,  0.99978F,  0.99993F, 0.99993F, 0.99998F, 0.99998F, 0.99999F,  0.99999F},
          {-0.99996F, -0.99996F, 0.98107F, 0.98107F, 0.78144F, 0.78144F, -0.71748F, -0.71748F,
           -0.58850F, -0.58850F, 0.32749F, 0.32749F, 0.76760F, 0.76760F, 0.92444F,  0.92444F,
           0.97590F,  0.97590F,  0.99236F, 0.99236F, 0.99758F, 0.99758F, 0.99923F,  0.99923F,
           0.99976F,  0.99976F,  0.99992F, 0.99992F, 0.99998F, 0.99998F, 0.99999F,  0.99999F},
          {-0.53283F, -0.53283F, 0.93324F, 0.93324F, 0.54865F, 0.54865F, -0.58294F, -0.58294F,
           -0.66628F, -0.66628F, 0.27387F, 0.27387F, 0.74696F, 0.74696F, 0.91752F,  0.91752F,
           0.97367F,  0.97367F,  0.99165F, 0.99165F, 0.99736F, 0.99736F, 0.99916F,  0.99916F,
           0.99974F,  0.99974F,  0.99992F, 0.99992F, 0.99997F, 0.99997F, 0.99999F,  0.99999F},
          {0.42418F,  0.42418F,  0.59798F, 0.59798F, 0.26144F, 0.26144F, -0.43002F, -0.43002F,
           -0.73739F, -0.73739F, 0.21938F, 0.21938F, 0.72556F, 0.72556F, 0.91030F,  0.91030F,
           0.97134F,  0.97134F,  0.99091F, 0.99091F, 0.99712F, 0.99712F, 0.99909F,  0.99909F,
           0.99971F,  0.99971F,  0.99991F, 0.99991F, 0.99997F, 0.99997F, 0.99999F,  0.99999F},
          {0.99120F,  0.99120F,  0.07855F, 0.07855F, -0.05169F, -0.05169F, -0.26354F, -0.26354F,
           -0.80114F, -0.80114F, 0.16420F, 0.16420F, 0.70344F,  0.70344F,  0.90280F,  0.90280F,
           0.96891F,  0.96891F,  0.99013F, 0.99013F, 0.99688F,  0.99688F,  0.99901F,  0.99901F,
           0.99969F,  0.99969F,  0.99990F, 0.99990F, 0.99997F,  0.99997F,  0.99999F,  0.99999F},
          {0.64692F,  0.64692F,  -0.46506F, -0.46506F, -0.35969F, -0.35969F, -0.08875F, -0.08875F,
           -0.85689F, -0.85689F, 0.10849F,  0.10849F,  0.68062F,  0.68062F,  0.89501F,  0.89501F,
           0.96639F,  0.96639F,  0.98933F,  0.98933F,  0.99662F,  0.99662F,  0.99893F,  0.99893F,
           0.99966F,  0.99966F,  0.99989F,  0.99989F,  0.99997F,  0.99997F,  0.99999F,  0.99999F},
          {-0.29214F, -0.29214F, -0.86545F, -0.86545F, -0.63203F, -0.63203F, 0.08885F, 0.08885F,
           -0.90407F, -0.90407F, 0.05245F,  0.05245F,  0.65711F,  0.65711F,  0.88693F, 0.88693F,
           0.96377F,  0.96377F,  0.98850F,  0.98850F,  0.99636F,  0.99636F,  0.99885F, 0.99885F,
           0.99964F,  0.99964F,  0.99988F,  0.99988F,  0.99996F,  0.99996F,  0.99999F, 0.99999F},
          {-0.96261F, -0.96261F, -0.99929F, -0.99929F, -0.84168F, -0.84168F, 0.26364F, 0.26364F,
           -0.94222F, -0.94222F, -0.00376F, -0.00376F, 0.63295F,  0.63295F,  0.87858F, 0.87858F,
           0.96106F,  0.96106F,  0.98763F,  0.98763F,  0.99608F,  0.99608F,  0.99876F, 0.99876F,
           0.99961F,  0.99961F,  0.99988F,  0.99988F,  0.99996F,  0.99996F,  0.99999F, 0.99999F},
          {-0.74806F, -0.74806F, -0.82537F, -0.82537F, -0.96787F, -0.96787F, 0.43012F, 0.43012F,
           -0.97096F, -0.97096F, -0.05996F, -0.05996F, 0.60816F,  0.60816F,  0.86995F, 0.86995F,
           0.95824F,  0.95824F,  0.98673F,  0.98673F,  0.99580F,  0.99580F,  0.99867F, 0.99867F,
           0.99958F,  0.99958F,  0.99987F,  0.99987F,  0.99996F,  0.99996F,  0.99999F, 0.99999F},
          {0.15425F,  0.15425F,  -0.39725F, -0.39725F, -0.99808F, -0.99808F, 0.58303F, 0.58303F,
           -0.98999F, -0.98999F, -0.11597F, -0.11597F, 0.58275F,  0.58275F,  0.86104F, 0.86104F,
           0.95534F,  0.95534F,  0.98580F,  0.98580F,  0.99550F,  0.99550F,  0.99858F, 0.99858F,
           0.99955F,  0.99955F,  0.99986F,  0.99986F,  0.99995F,  0.99995F,  0.99999F, 0.99999F},
          {0.91474F,  0.91474F,  0.15322F,  0.15322F,  -0.92930F, -0.92930F, 0.71755F, 0.71755F,
           -0.99914F, -0.99914F, -0.17161F, -0.17161F, 0.55677F,  0.55677F,  0.85186F, 0.85186F,
           0.95233F,  0.95233F,  0.98484F,  0.98484F,  0.99520F,  0.99520F,  0.99848F, 0.99848F,
           0.99952F,  0.99952F,  0.99985F,  0.99985F,  0.99995F,  0.99995F,  0.99998F, 0.99998F}}}};
    xt::xarray<float> expected_sin = {
        {{{0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.84147F, 0.84147F, 0.53317F, 0.53317F, 0.31098F, 0.31098F, 0.17689F, 0.17689F, 0.09983F, 0.09983F, 0.05620F,
           0.05620F, 0.03162F, 0.03162F, 0.01778F, 0.01778F, 0.01000F, 0.01000F, 0.00562F, 0.00562F, 0.00316F, 0.00316F,
           0.00178F, 0.00178F, 0.00100F, 0.00100F, 0.00056F, 0.00056F, 0.00032F, 0.00032F, 0.00018F, 0.00018F},
          {0.90930F, 0.90930F, 0.90213F, 0.90213F, 0.59113F, 0.59113F, 0.34821F, 0.34821F, 0.19867F, 0.19867F, 0.11223F,
           0.11223F, 0.06320F, 0.06320F, 0.03556F, 0.03556F, 0.02000F, 0.02000F, 0.01125F, 0.01125F, 0.00632F, 0.00632F,
           0.00356F, 0.00356F, 0.00200F, 0.00200F, 0.00112F, 0.00112F, 0.00063F, 0.00063F, 0.00036F, 0.00036F},
          {0.14112F, 0.14112F, 0.99325F, 0.99325F, 0.81265F, 0.81265F, 0.50854F, 0.50854F, 0.29552F, 0.29552F, 0.16790F,
           0.16790F, 0.09473F, 0.09473F, 0.05332F, 0.05332F, 0.03000F, 0.03000F, 0.01687F, 0.01687F, 0.00949F, 0.00949F,
           0.00533F, 0.00533F, 0.00300F, 0.00300F, 0.00169F, 0.00169F, 0.00095F, 0.00095F, 0.00053F, 0.00053F},
          {-0.75680F, -0.75680F, 0.77847F, 0.77847F, 0.95358F, 0.95358F, 0.65283F, 0.65283F,
           0.38942F,  0.38942F,  0.22304F, 0.22304F, 0.12615F, 0.12615F, 0.07107F, 0.07107F,
           0.03999F,  0.03999F,  0.02249F, 0.02249F, 0.01265F, 0.01265F, 0.00711F, 0.00711F,
           0.00400F,  0.00400F,  0.00225F, 0.00225F, 0.00126F, 0.00126F, 0.00071F, 0.00071F},
          {-0.95892F, -0.95892F, 0.32394F, 0.32394F, 0.99995F, 0.99995F, 0.77653F, 0.77653F,
           0.47943F,  0.47943F,  0.27748F, 0.27748F, 0.15746F, 0.15746F, 0.08880F, 0.08880F,
           0.04998F,  0.04998F,  0.02811F, 0.02811F, 0.01581F, 0.01581F, 0.00889F, 0.00889F,
           0.00500F,  0.00500F,  0.00281F, 0.00281F, 0.00158F, 0.00158F, 0.00089F, 0.00089F},
          {-0.27942F, -0.27942F, -0.23037F, -0.23037F, 0.94715F, 0.94715F, 0.87574F, 0.87574F,
           0.56464F,  0.56464F,  0.33104F,  0.33104F,  0.18860F, 0.18860F, 0.10649F, 0.10649F,
           0.05996F,  0.05996F,  0.03373F,  0.03373F,  0.01897F, 0.01897F, 0.01067F, 0.01067F,
           0.00600F,  0.00600F,  0.00337F,  0.00337F,  0.00190F, 0.00190F, 0.00107F, 0.00107F},
          {0.65699F, 0.65699F, -0.71372F, -0.71372F, 0.80042F, 0.80042F, 0.94733F, 0.94733F,
           0.64422F, 0.64422F, 0.38355F,  0.38355F,  0.21956F, 0.21956F, 0.12416F, 0.12416F,
           0.06994F, 0.06994F, 0.03935F,  0.03935F,  0.02213F, 0.02213F, 0.01245F, 0.01245F,
           0.00700F, 0.00700F, 0.00394F,  0.00394F,  0.00221F, 0.00221F, 0.00124F, 0.00124F},
          {0.98936F, 0.98936F, -0.97726F, -0.97726F, 0.57432F, 0.57432F, 0.98904F, 0.98904F,
           0.71736F, 0.71736F, 0.43485F,  0.43485F,  0.25029F, 0.25029F, 0.14178F, 0.14178F,
           0.07991F, 0.07991F, 0.04497F,  0.04497F,  0.02530F, 0.02530F, 0.01423F, 0.01423F,
           0.00800F, 0.00800F, 0.00450F,  0.00450F,  0.00253F, 0.00253F, 0.00142F, 0.00142F},
          {0.41212F, 0.41212F, -0.93982F, -0.93982F, 0.29126F, 0.29126F, 0.99956F, 0.99956F,
           0.78333F, 0.78333F, 0.48478F,  0.48478F,  0.28078F, 0.28078F, 0.15936F, 0.15936F,
           0.08988F, 0.08988F, 0.05059F,  0.05059F,  0.02846F, 0.02846F, 0.01600F, 0.01600F,
           0.00900F, 0.00900F, 0.00506F,  0.00506F,  0.00285F, 0.00285F, 0.00160F, 0.00160F},
          {-0.54402F, -0.54402F, -0.61294F, -0.61294F, -0.02068F, -0.02068F, 0.97855F, 0.97855F,
           0.84147F,  0.84147F,  0.53317F,  0.53317F,  0.31098F,  0.31098F,  0.17689F, 0.17689F,
           0.09983F,  0.09983F,  0.05620F,  0.05620F,  0.03162F,  0.03162F,  0.01778F, 0.01778F,
           0.01000F,  0.01000F,  0.00562F,  0.00562F,  0.00316F,  0.00316F,  0.00178F, 0.00178F},
          {-0.99999F, -0.99999F, -0.09728F, -0.09728F, -0.33057F, -0.33057F, 0.92668F, 0.92668F,
           0.89121F,  0.89121F,  0.57988F,  0.57988F,  0.34088F,  0.34088F,  0.19437F, 0.19437F,
           0.10978F,  0.10978F,  0.06182F,  0.06182F,  0.03478F,  0.03478F,  0.01956F, 0.01956F,
           0.01100F,  0.01100F,  0.00619F,  0.00619F,  0.00348F,  0.00348F,  0.00196F, 0.00196F},
          {-0.53657F, -0.53657F, 0.44834F, 0.44834F, -0.60768F, -0.60768F, 0.84558F, 0.84558F,
           0.93204F,  0.93204F,  0.62475F, 0.62475F, 0.37043F,  0.37043F,  0.21178F, 0.21178F,
           0.11971F,  0.11971F,  0.06743F, 0.06743F, 0.03794F,  0.03794F,  0.02134F, 0.02134F,
           0.01200F,  0.01200F,  0.00675F, 0.00675F, 0.00379F,  0.00379F,  0.00213F, 0.00213F},
          {0.42017F, 0.42017F, 0.85588F, 0.85588F, -0.82453F, -0.82453F, 0.73782F, 0.73782F,
           0.96356F, 0.96356F, 0.66765F, 0.66765F, 0.39961F,  0.39961F,  0.22912F, 0.22912F,
           0.12963F, 0.12963F, 0.07304F, 0.07304F, 0.04110F,  0.04110F,  0.02312F, 0.02312F,
           0.01300F, 0.01300F, 0.00731F, 0.00731F, 0.00411F,  0.00411F,  0.00231F, 0.00231F},
          {0.99061F, 0.99061F, 0.99982F, 0.99982F, -0.95961F, -0.95961F, 0.60678F, 0.60678F,
           0.98545F, 0.98545F, 0.70843F, 0.70843F, 0.42840F,  0.42840F,  0.24640F, 0.24640F,
           0.13954F, 0.13954F, 0.07865F, 0.07865F, 0.04426F,  0.04426F,  0.02489F, 0.02489F,
           0.01400F, 0.01400F, 0.00787F, 0.00787F, 0.00443F,  0.00443F,  0.00249F, 0.00249F},
          {0.65029F, 0.65029F, 0.83584F, 0.83584F, -0.99952F, -0.99952F, 0.45660F, 0.45660F,
           0.99749F, 0.99749F, 0.74698F, 0.74698F, 0.45675F,  0.45675F,  0.26359F, 0.26359F,
           0.14944F, 0.14944F, 0.08425F, 0.08425F, 0.04742F,  0.04742F,  0.02667F, 0.02667F,
           0.01500F, 0.01500F, 0.00844F, 0.00844F, 0.00474F,  0.00474F,  0.00267F, 0.00267F},
          {-0.28790F, -0.28790F, 0.41443F, 0.41443F, -0.94031F, -0.94031F, 0.29203F, 0.29203F,
           0.99957F,  0.99957F,  0.78317F, 0.78317F, 0.48465F,  0.48465F,  0.28070F, 0.28070F,
           0.15932F,  0.15932F,  0.08985F, 0.08985F, 0.05057F,  0.05057F,  0.02845F, 0.02845F,
           0.01600F,  0.01600F,  0.00900F, 0.00900F, 0.00506F,  0.00506F,  0.00285F, 0.00285F},
          {-0.96140F, -0.96140F, -0.13462F, -0.13462F, -0.78785F, -0.78785F, 0.11824F, 0.11824F,
           0.99166F,  0.99166F,  0.81688F,  0.81688F,  0.51207F,  0.51207F,  0.29772F, 0.29772F,
           0.16918F,  0.16918F,  0.09545F,  0.09545F,  0.05373F,  0.05373F,  0.03023F, 0.03023F,
           0.01700F,  0.01700F,  0.00956F,  0.00956F,  0.00538F,  0.00538F,  0.00302F, 0.00302F},
          {-0.75099F, -0.75099F, -0.64220F, -0.64220F, -0.55726F, -0.55726F, -0.05928F, -0.05928F,
           0.97385F,  0.97385F,  0.84801F,  0.84801F,  0.53897F,  0.53897F,  0.31465F,  0.31465F,
           0.17903F,  0.17903F,  0.10105F,  0.10105F,  0.05689F,  0.05689F,  0.03200F,  0.03200F,
           0.01800F,  0.01800F,  0.01012F,  0.01012F,  0.00569F,  0.00569F,  0.00320F,  0.00320F},
          {0.14988F, 0.14988F, -0.95200F, -0.95200F, -0.27141F, -0.27141F, -0.23492F, -0.23492F,
           0.94630F, 0.94630F, 0.87645F,  0.87645F,  0.56533F,  0.56533F,  0.33148F,  0.33148F,
           0.18886F, 0.18886F, 0.10664F,  0.10664F,  0.06005F,  0.06005F,  0.03378F,  0.03378F,
           0.01900F, 0.01900F, 0.01068F,  0.01068F,  0.00601F,  0.00601F,  0.00338F,  0.00338F},
          {0.91295F, 0.91295F, -0.96860F, -0.96860F, 0.04136F, 0.04136F, -0.40316F, -0.40316F,
           0.90930F, 0.90930F, 0.90213F,  0.90213F,  0.59113F, 0.59113F, 0.34821F,  0.34821F,
           0.19867F, 0.19867F, 0.11223F,  0.11223F,  0.06320F, 0.06320F, 0.03556F,  0.03556F,
           0.02000F, 0.02000F, 0.01125F,  0.01125F,  0.00632F, 0.00632F, 0.00356F,  0.00356F},
          {0.83666F, 0.83666F, -0.68689F, -0.68689F, 0.35002F, 0.35002F, -0.55868F, -0.55868F,
           0.86321F, 0.86321F, 0.92495F,  0.92495F,  0.61633F, 0.61633F, 0.36482F,  0.36482F,
           0.20846F, 0.20846F, 0.11782F,  0.11782F,  0.06636F, 0.06636F, 0.03734F,  0.03734F,
           0.02100F, 0.02100F, 0.01181F,  0.01181F,  0.00664F, 0.00664F, 0.00373F,  0.00373F},
          {-0.00885F, -0.00885F, -0.19363F, -0.19363F, 0.62398F, 0.62398F, -0.69658F, -0.69658F,
           0.80850F,  0.80850F,  0.94485F,  0.94485F,  0.64092F, 0.64092F, 0.38132F,  0.38132F,
           0.21823F,  0.21823F,  0.12340F,  0.12340F,  0.06951F, 0.06951F, 0.03911F,  0.03911F,
           0.02200F,  0.02200F,  0.01237F,  0.01237F,  0.00696F, 0.00696F, 0.00391F,  0.00391F},
          {-0.84622F, -0.84622F, 0.35926F, 0.35926F, 0.83606F, 0.83606F, -0.81251F, -0.81251F,
           0.74571F,  0.74571F,  0.96177F, 0.96177F, 0.66487F, 0.66487F, 0.39770F,  0.39770F,
           0.22798F,  0.22798F,  0.12898F, 0.12898F, 0.07267F, 0.07267F, 0.04089F,  0.04089F,
           0.02300F,  0.02300F,  0.01293F, 0.01293F, 0.00727F, 0.00727F, 0.00409F,  0.00409F},
          {-0.90558F, -0.90558F, 0.80151F, 0.80151F, 0.96522F, 0.96522F, -0.90282F, -0.90282F,
           0.67546F,  0.67546F,  0.97564F, 0.97564F, 0.68816F, 0.68816F, 0.41395F,  0.41395F,
           0.23770F,  0.23770F,  0.13455F, 0.13455F, 0.07582F, 0.07582F, 0.04267F,  0.04267F,
           0.02400F,  0.02400F,  0.01350F, 0.01350F, 0.00759F, 0.00759F, 0.00427F,  0.00427F},
          {-0.13235F, -0.13235F, 0.99691F, 0.99691F, 0.99866F, 0.99866F, -0.96465F, -0.96465F,
           0.59847F,  0.59847F,  0.98643F, 0.98643F, 0.71075F, 0.71075F, 0.43007F,  0.43007F,
           0.24740F,  0.24740F,  0.14012F, 0.14012F, 0.07897F, 0.07897F, 0.04444F,  0.04444F,
           0.02500F,  0.02500F,  0.01406F, 0.01406F, 0.00791F, 0.00791F, 0.00445F,  0.00445F},
          {0.76256F, 0.76256F, 0.88528F, 0.88528F, 0.93307F, 0.93307F, -0.99605F, -0.99605F,
           0.51550F, 0.51550F, 0.99410F, 0.99410F, 0.73264F, 0.73264F, 0.44605F,  0.44605F,
           0.25708F, 0.25708F, 0.14569F, 0.14569F, 0.08213F, 0.08213F, 0.04622F,  0.04622F,
           0.02600F, 0.02600F, 0.01462F, 0.01462F, 0.00822F, 0.00822F, 0.00462F,  0.00462F},
          {0.95638F, 0.95638F, 0.50099F, 0.50099F, 0.77495F, 0.77495F, -0.99605F, -0.99605F,
           0.42738F, 0.42738F, 0.99862F, 0.99862F, 0.75379F, 0.75379F, 0.46190F,  0.46190F,
           0.26673F, 0.26673F, 0.15125F, 0.15125F, 0.08528F, 0.08528F, 0.04800F,  0.04800F,
           0.02700F, 0.02700F, 0.01518F, 0.01518F, 0.00854F, 0.00854F, 0.00480F,  0.00480F},
          {0.27091F, 0.27091F, -0.03759F, -0.03759F, 0.53997F, 0.53997F, -0.96462F, -0.96462F,
           0.33499F, 0.33499F, 0.99999F,  0.99999F,  0.77419F, 0.77419F, 0.47760F,  0.47760F,
           0.27636F, 0.27636F, 0.15681F,  0.15681F,  0.08843F, 0.08843F, 0.04977F,  0.04977F,
           0.02800F, 0.02800F, 0.01574F,  0.01574F,  0.00885F, 0.00885F, 0.00498F,  0.00498F},
          {-0.66363F, -0.66363F, -0.56459F, -0.56459F, 0.25145F, 0.25145F, -0.90277F, -0.90277F,
           0.23925F,  0.23925F,  0.99820F,  0.99820F,  0.79382F, 0.79382F, 0.49314F,  0.49314F,
           0.28595F,  0.28595F,  0.16236F,  0.16236F,  0.09158F, 0.09158F, 0.05155F,  0.05155F,
           0.02900F,  0.02900F,  0.01631F,  0.01631F,  0.00917F, 0.00917F, 0.00516F,  0.00516F},
          {-0.98803F, -0.98803F, -0.91771F, -0.91771F, -0.06201F, -0.06201F, -0.81245F, -0.81245F,
           0.14112F,  0.14112F,  0.99325F,  0.99325F,  0.81265F,  0.81265F,  0.50854F,  0.50854F,
           0.29552F,  0.29552F,  0.16790F,  0.16790F,  0.09473F,  0.09473F,  0.05332F,  0.05332F,
           0.03000F,  0.03000F,  0.01687F,  0.01687F,  0.00949F,  0.00949F,  0.00533F,  0.00533F},
          {-0.40404F, -0.40404F, -0.98819F, -0.98819F, -0.36933F, -0.36933F, -0.69651F, -0.69651F,
           0.04158F,  0.04158F,  0.98517F,  0.98517F,  0.83067F,  0.83067F,  0.52377F,  0.52377F,
           0.30506F,  0.30506F,  0.17344F,  0.17344F,  0.09787F,  0.09787F,  0.05510F,  0.05510F,
           0.03100F,  0.03100F,  0.01743F,  0.01743F,  0.00980F,  0.00980F,  0.00551F,  0.00551F}}}};
    xt::xarray<float> expected_trans_mat = {
        {{{0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {-1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           -1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           -1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           -1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F},
          {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,  0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F}}}};

    auto rope_params = ops::build_rope_params(
        /*sequence_length=*/32,
        /*head_dim=*/32);

    EXPECT_TRUE(xt::allclose(expected_cos, core::to_xtensor(rope_params.cos_cache), /*rtol=*/0.01F, /*atol=*/0.03F));
    EXPECT_TRUE(xt::allclose(expected_sin, core::to_xtensor(rope_params.sin_cache), /*rtol=*/0.01F, /*atol=*/0.03F));
    EXPECT_TRUE(xt::allclose(expected_trans_mat, core::to_xtensor(rope_params.trans_mat)));
}

TEST_F(RoPETest, ForwardTest) {
    // Runs the RotaryEmbedding module on an all-ones query tensor of shape (1, 2, 5, 32) and compares the
    // device result against precomputed reference values.
    //
    // Kernel constraints exercised here:
    //   * head_dim must be a multiple of TILE_WIDTH;
    //   * head_dim must be <= 256.
    constexpr double kRelTol = 2e-1;
    constexpr double kAbsTol = 2e-1;

    // Reference output for the all-ones input. Both entries along dim 1 receive identical input, so their
    // expected rows are identical; position 0 (first row of each) is left unrotated (all ones).
    xt::xarray<float> expected_xq_out = {
        {{{1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F,
           1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F,
           1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F},
          {-0.30078F, 1.38281F, 0.31641F, 1.38281F, 0.64062F, 1.25781F, 0.80859F, 1.16406F,
           0.89844F,  1.09375F, 0.94531F, 1.05469F, 0.96875F, 1.03125F, 0.98438F, 1.01562F,
           0.99219F,  1.00781F, 0.99609F, 1.00781F, 0.99609F, 1.00000F, 1.00000F, 1.00000F,
           1.00000F,  1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F},
          {-1.32812F, 0.49414F, -0.47070F, 1.33594F, 0.21484F, 1.39844F, 0.58984F, 1.28906F,
           0.78125F,  1.17969F, 0.87891F,  1.10156F, 0.93359F, 1.06250F, 0.96484F, 1.03906F,
           0.98047F,  1.02344F, 0.98828F,  1.01562F, 0.99609F, 1.00781F, 0.99609F, 1.00781F,
           1.00000F,  1.00000F, 1.00000F,  1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F},
          {-1.13281F, -0.84766F, -1.10938F, 0.87500F, -0.23047F, 1.39844F, 0.35156F, 1.36719F,
           0.66406F,  1.25000F,  0.81641F,  1.15625F, 0.90234F,  1.09375F, 0.94531F, 1.05469F,
           0.96875F,  1.03125F,  0.98438F,  1.01562F, 0.99219F,  1.00781F, 0.99609F, 1.00781F,
           0.99609F,  1.00000F,  1.00000F,  1.00000F, 1.00000F,  1.00000F, 1.00000F, 1.00000F},
          {0.10547F, -1.41406F, -1.40625F, 0.14844F, -0.65234F, 1.25781F, 0.10547F, 1.41406F,
           0.53516F, 1.31250F,  0.75391F,  1.20312F, 0.86719F,  1.11719F, 0.92578F, 1.07031F,
           0.96094F, 1.03906F,  0.97656F,  1.02344F, 0.98828F,  1.01562F, 0.99219F, 1.00781F,
           0.99609F, 1.00781F,  1.00000F,  1.00000F, 1.00000F,  1.00000F, 1.00000F, 1.00000F}},
         {{1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F,
           1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F,
           1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F},
          {-0.30078F, 1.38281F, 0.31641F, 1.38281F, 0.64062F, 1.25781F, 0.80859F, 1.16406F,
           0.89844F,  1.09375F, 0.94531F, 1.05469F, 0.96875F, 1.03125F, 0.98438F, 1.01562F,
           0.99219F,  1.00781F, 0.99609F, 1.00781F, 0.99609F, 1.00000F, 1.00000F, 1.00000F,
           1.00000F,  1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F},
          {-1.32812F, 0.49414F, -0.47070F, 1.33594F, 0.21484F, 1.39844F, 0.58984F, 1.28906F,
           0.78125F,  1.17969F, 0.87891F,  1.10156F, 0.93359F, 1.06250F, 0.96484F, 1.03906F,
           0.98047F,  1.02344F, 0.98828F,  1.01562F, 0.99609F, 1.00781F, 0.99609F, 1.00781F,
           1.00000F,  1.00000F, 1.00000F,  1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F},
          {-1.13281F, -0.84766F, -1.10938F, 0.87500F, -0.23047F, 1.39844F, 0.35156F, 1.36719F,
           0.66406F,  1.25000F,  0.81641F,  1.15625F, 0.90234F,  1.09375F, 0.94531F, 1.05469F,
           0.96875F,  1.03125F,  0.98438F,  1.01562F, 0.99219F,  1.00781F, 0.99609F, 1.00781F,
           0.99609F,  1.00000F,  1.00000F,  1.00000F, 1.00000F,  1.00000F, 1.00000F, 1.00000F},
          {0.10547F, -1.41406F, -1.40625F, 0.14844F, -0.65234F, 1.25781F, 0.10547F, 1.41406F,
           0.53516F, 1.31250F,  0.75391F,  1.20312F, 0.86719F,  1.11719F, 0.92578F, 1.07031F,
           0.96094F, 1.03906F,  0.97656F,  1.02344F, 0.98828F,  1.01562F, 0.99219F, 1.00781F,
           0.99609F, 1.00781F,  1.00000F,  1.00000F, 1.00000F,  1.00000F, 1.00000F, 1.00000F}}}};

    // Input query tensor: all ones, so the output directly exposes the per-position rotation.
    xt::xarray<float> xq_input = xt::ones<float>({1, 2, 5, 32});

    auto* device = &ttml::autograd::ctx().get_device();

    // RoPE parameters sized to match the input: 5 positions, head dimension 32.
    auto rope_params = ops::build_rope_params(
        /*sequence_length=*/5,
        /*head_dim=*/32);
    RotaryEmbedding rope_module(rope_params);

    auto query = autograd::create_tensor(core::from_xtensor(xq_input, device));
    auto rotated = rope_module(query);
    auto rotated_host = core::to_xtensor(rotated->get_value());

    EXPECT_TRUE(xt::allclose(rotated_host, expected_xq_out, kRelTol, kAbsTol));
}

TEST_F(RoPETest, BackwardTest) {
    // Checks the gradient of mse_loss(rope(xq), ones) with respect to the all-ones query xq of shape (1, 2, 5, 32).
    //
    // Kernel constraints exercised here:
    //   * head_dim must be a multiple of TILE_WIDTH;
    //   * head_dim must be <= 256.

    // Input query tensor
    xt::xarray<float> xq = xt::ones<float>({1, 2, 5, 32});

    // Expected dL/dxq. Position 0 (first row of each entry along dim 1) is all zeros: its rotated output equals the
    // all-ones target, so the loss has no gradient there. The remaining values are consistent with
    // dL/dout = 2 * (out - 1) / N (N = 320 elements) pulled back through the transposed per-pair rotation.
    // NOTE(review): that reading assumes mse_loss uses a mean reduction; confirm against ops/losses.
    xt::xarray<float> expected_grad = {
        {{{0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {-0.00238F, 0.00812F, -0.00238F, 0.00433F, -0.00163F, 0.00223F, -0.00100F, 0.00123F,
           -0.00060F, 0.00065F, -0.00035F, 0.00036F, -0.00019F, 0.00020F, -0.00012F, 0.00010F,
           -0.00007F, 0.00005F, -0.00002F, 0.00005F, -0.00002F, 0.00000F, 0.00000F,  0.00000F,
           0.00000F,  0.00000F, 0.00000F,  0.00000F, 0.00000F,  0.00000F, 0.00000F,  0.00000F},
          {0.00319F,  0.01459F, -0.00208F, 0.00922F, -0.00250F, 0.00488F, -0.00177F, 0.00259F,
           -0.00111F, 0.00137F, -0.00065F, 0.00076F, -0.00039F, 0.00042F, -0.00021F, 0.00020F,
           -0.00012F, 0.00015F, -0.00007F, 0.00005F, -0.00005F, 0.00005F, -0.00002F, 0.00000F,
           -0.00002F, 0.00000F, 0.00000F,  0.00000F, 0.00000F,  0.00000F, 0.00000F,  0.00000F},
          {0.01154F,  0.01331F, 0.00077F,  0.01318F, -0.00244F, 0.00769F, -0.00232F, 0.00403F,
           -0.00157F, 0.00212F, -0.00097F, 0.00115F, -0.00055F, 0.00064F, -0.00032F, 0.00036F,
           -0.00019F, 0.00020F, -0.00010F, 0.00010F, -0.00005F, 0.00005F, -0.00002F, 0.00005F,
           -0.00002F, 0.00000F, 0.00000F,  0.00000F, 0.00000F,  0.00000F, 0.00000F,  0.00000F},
          {0.01508F,  0.00558F, 0.00531F,  0.01508F, -0.00159F, 0.01038F, -0.00255F, 0.00562F,
           -0.00194F, 0.00294F, -0.00125F, 0.00154F, -0.00073F, 0.00083F, -0.00043F, 0.00047F,
           -0.00023F, 0.00025F, -0.00014F, 0.00015F, -0.00007F, 0.00010F, -0.00005F, 0.00005F,
           -0.00002F, 0.00005F, -0.00002F, 0.00000F, 0.00000F,  0.00000F, 0.00000F,  0.00000F}},
         {{0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F,
           0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F},
          {-0.00238F, 0.00812F, -0.00238F, 0.00433F, -0.00163F, 0.00223F, -0.00100F, 0.00123F,
           -0.00060F, 0.00065F, -0.00035F, 0.00036F, -0.00019F, 0.00020F, -0.00012F, 0.00010F,
           -0.00007F, 0.00005F, -0.00002F, 0.00005F, -0.00002F, 0.00000F, 0.00000F,  0.00000F,
           0.00000F,  0.00000F, 0.00000F,  0.00000F, 0.00000F,  0.00000F, 0.00000F,  0.00000F},
          {0.00319F,  0.01459F, -0.00208F, 0.00922F, -0.00250F, 0.00488F, -0.00177F, 0.00259F,
           -0.00111F, 0.00137F, -0.00065F, 0.00076F, -0.00039F, 0.00042F, -0.00021F, 0.00020F,
           -0.00012F, 0.00015F, -0.00007F, 0.00005F, -0.00005F, 0.00005F, -0.00002F, 0.00000F,
           -0.00002F, 0.00000F, 0.00000F,  0.00000F, 0.00000F,  0.00000F, 0.00000F,  0.00000F},
          {0.01154F,  0.01331F, 0.00077F,  0.01318F, -0.00244F, 0.00769F, -0.00232F, 0.00403F,
           -0.00157F, 0.00212F, -0.00097F, 0.00115F, -0.00055F, 0.00064F, -0.00032F, 0.00036F,
           -0.00019F, 0.00020F, -0.00010F, 0.00010F, -0.00005F, 0.00005F, -0.00002F, 0.00005F,
           -0.00002F, 0.00000F, 0.00000F,  0.00000F, 0.00000F,  0.00000F, 0.00000F,  0.00000F},
          {0.01508F,  0.00558F, 0.00531F,  0.01508F, -0.00159F, 0.01038F, -0.00255F, 0.00562F,
           -0.00194F, 0.00294F, -0.00125F, 0.00154F, -0.00073F, 0.00083F, -0.00043F, 0.00047F,
           -0.00023F, 0.00025F, -0.00014F, 0.00015F, -0.00007F, 0.00010F, -0.00005F, 0.00005F,
           -0.00002F, 0.00005F, -0.00002F, 0.00000F, 0.00000F,  0.00000F, 0.00000F,  0.00000F}}}};

    auto* device = &ttml::autograd::ctx().get_device();
    auto rope_params = ops::build_rope_params(
        /*sequence_length=*/5,
        /*head_dim=*/32);
    auto rope_mod = modules::RotaryEmbedding(rope_params);

    auto xq_autograd_tensor = autograd::create_tensor(core::from_xtensor(xq, device));

    auto actual_xq_out = rope_mod(xq_autograd_tensor);
    auto target = autograd::create_tensor(core::from_xtensor(xq, device));  // just need ones for mse target, reusing xq

    auto loss = ttml::ops::mse_loss(actual_xq_out, target);
    loss->backward();

    auto actual_grad = core::to_xtensor(xq_autograd_tensor->get_grad());

    // Every expected gradient is below ~0.015 in magnitude. An absolute tolerance of 0.2 would accept any gradient
    // in [-0.2, 0.2], including an all-zero gradient, so the check would be vacuous. 1e-3 still leaves headroom
    // for reduced-precision device arithmetic while catching missing, sign-flipped or mis-scaled gradients.
    constexpr double kRelTol = 2e-1;
    constexpr double kAbsTol = 1e-3;
    EXPECT_TRUE(xt::allclose(actual_grad, expected_grad, kRelTol, kAbsTol));
}

}  // namespace ttml::modules::tests
